
Conversation

@xmfan xmfan (Member) commented Nov 8, 2025

Stacked PRs:


Mark input tokens to routed experts as dynamic to avoid a recompile

This saves one recompile, and you can see that the input tokens are already dynamic in the first compiled graph:

class GraphModule(torch.nn.Module):
    def forward(...s77: "Sym(s77)", L_x_: "bf16[s77, 5120][5120, 1]cuda:0"...

I verified that this also fixes the AC recompile issue in #1971, but I'm keeping torch._C._dynamo.eval_frame._set_lru_cache(False), as other recompile reasons could still pop up.
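
For illustration, a minimal sketch of the idea (not the actual torchtitan call site; the function and shapes are made up, only torch._dynamo.mark_dynamic is the real API): marking the token dimension dynamic up front keeps torch.compile from specializing on the first-seen token count and recompiling when it changes.

```python
import torch
import torch.nn.functional as F

# Sketch only: mark dim 0 of the tokens routed to an expert as dynamic so the
# first compiled graph already uses a symbolic size (e.g. "Sym(s77)").

@torch.compile
def expert_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return F.silu(x @ w)

w = torch.randn(5120, 5120)
for num_tokens in (64, 96, 128):        # routed token count varies per step
    x = torch.randn(num_tokens, 5120)
    torch._dynamo.mark_dynamic(x, 0)    # dim 0 compiles as a SymInt
    expert_forward(x, w)                # no recompile after the first call
```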

@jquesnelle (Contributor) commented:
The fix for #1971 requires a PyTorch nightly from the last few days (_set_lru_cache was only just added), so this may be the preferable fix, since it allows staying on PyTorch stable (2.9).

@wwwjn wwwjn (Contributor) left a comment:

This change LGTM! nit: Can you explain more why we still need torch._C._dynamo.eval_frame._set_lru_cache(False) to fix #1971? If no specific reason, should we remove it?

@tianyu-l tianyu-l (Contributor) left a comment:

sgtm

@xmfan xmfan (Member, Author) commented Nov 17, 2025

> This change LGTM! nit: Can you explain more why we still need torch._C._dynamo.eval_frame._set_lru_cache(False) to fix #1971? If no specific reason, should we remove it?

There are many issues like #1971 that can pop up, and they may not have good error messages. Keeping torch._C._dynamo.eval_frame._set_lru_cache(False) in the codebase will protect against all of them.
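
As a hypothetical placement sketch (not a quote from the torchtitan source), the defensive setting could be guarded so it only runs on builds where the private API exists:

```python
import torch

# Hypothetical guard: _set_lru_cache only exists in recent nightlies,
# so skip the call on older PyTorch builds.
eval_frame = torch._C._dynamo.eval_frame
if hasattr(eval_frame, "_set_lru_cache"):
    eval_frame._set_lru_cache(False)
```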

